Parte 10: Aprendizaje Federado con Agregación Encriptada de Gradientes

En las últimas secciones hemos aprendido sobre el cómputo encriptado construyendo varios programas simples. En esta sección, regresaremos al Demo de Aprendizaje Federado de la parte 4, donde teníamos un "agregador confiable" quien es el responsable de promediar las actualizaciones de los modelos de varios trabajadores.

Ahora vamos a usar nuestras nuevas herramientas de cómputo encriptado para dispensar este agregador confiable ya que no es ideal tenerlo porque asume que podemos encontrar a alguien lo suficientemente confiable para que tenga acceso a esta información sensible. Esto no siempre es el caso.

Por lo tanto, en este notebook mostraremos cómo podemos usar la computación segura multi-parte (CSMP) para realizar una agregación segura de tal manera que necesitemos un "agregador seguro".

Autores:

Traductores:

Sección 1: Aprendizaje Federado Normal

Primero, aquí hay código que realiza un aprendizaje federado clásico en el conjunto de datos Boston Housing. Esta sección del código puede desglosarse en varias secciones.

Configuración


In [ ]:
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class Parser:
    """Parámetros para el entrenamiento"""
    def __init__(self):
        self.epochs = 10
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
    
args = Parser()

torch.manual_seed(args.seed)
kwargs = {}

Cargar los Datos


In [ ]:
with open('../data/BostonHousing/boston_housing.pickle','rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()
# preprocesamiento
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0. # la columna 3 es binaria
dev[:, 3] = 1.  # así que no la estandarizamos
X = (X - mean) / dev
X_test = (X_test - mean) / dev
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

Estructura de la Red Neuronal


In [ ]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc3 = nn.Linear(24, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
optimizer = optim.SGD(model.parameters(), lr=args.lr)

Enganche de Pytorch


In [ ]:
import syft as sy

hook = sy.TorchHook(torch)
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")

compute_nodes = [bob, alice]

Mandar los datos a los trabajadores
Usualmente ya lo tendrían, eso sólo es para el demo.


In [ ]:
train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    train_distributed_dataset.append((data, target))

Función de Entrenamiento


In [ ]:
def train(epoch):
    model.train()
    for batch_idx, (data,target) in enumerate(train_distributed_dataset):
        worker = data.location
        model.send(worker)

        optimizer.zero_grad()
        # actualiza el modelo
        pred = model(data)
        loss = F.mse_loss(pred.view(-1), target)
        loss.backward()
        optimizer.step()
        model.get()
            
        if batch_idx % args.log_interval == 0:
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * data.shape[0], len(train_loader),
                       100. * batch_idx / len(train_loader), loss.item()))

Función para Pruebas


In [ ]:
def test():
    model.eval()
    test_loss = 0
    for data, target in test_loader:
        output = model(data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # suma la pérdida
        pred = output.data.max(1, keepdim=True)[1] # obtén el índice del máximo de la probabilidad logarítmica
        
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}\n'.format(test_loss))

Entrenando el Modelo


In [ ]:
import time

In [ ]:
t = time.time()

for epoch in range(1, args.epochs + 1):
    train(epoch)

    
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Calculando el Desempeño


In [ ]:
test()

Sección 2: Añadiendo la Agregación Encriptada

Ahora vamos a modificar este ejemplo sutilmente para agregar los gradientes de manera encriptada. La diferencia principal está en las líneas 1 o 2 del código en la función train(), que mostraremos. Por el momento, vamos a reprocesar nuestros datos e inicializar el modelo para bob y alice.


In [ ]:
remote_dataset = (list(),list())

train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))

def update(data, target, model, optimizer):
    model.send(data.location)
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

bobs_model = Net()
alices_model = Net()

bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)

models = [bobs_model, alices_model]
params = [list(bobs_model.parameters()), list(alices_model.parameters())]
optimizers = [bobs_optimizer, alices_optimizer]

Construyendo la Lógica de Entrenamiento

La única diferencia real está dentro del método de entrenamiento. Vamos a ver esto paso por paso.

Parte A: Entrenamiento:


In [ ]:
# esto selecciona el batch que entrenaremos
data_index = 0
# actualiza los modelos remotos
# podríamos iterar esto múltiples veces antes de proceder, pero sólo vamos a hacer una iteración por trabajador aquí
for remote_index in range(len(compute_nodes)):
    data, target = remote_dataset[remote_index][data_index]
    models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

Parte B: Agregación Encriptada


In [ ]:
# crea una lista donde depositaremos nuestro modelo modelo promedio encriptado
new_params = list()

In [ ]:
# itera sobre cada parámetro
for param_i in range(len(params[0])):

    # para cada trabajador
    spdz_params = list()
    for remote_index in range(len(compute_nodes)):
        
        # selecciona el parámetro idéntico para cada trabajador y copia
        copy_of_parameter = params[remote_index][param_i].copy()
        
        # como la CSMP sólo puede trabajar con enteros (sin puntos flotantes), necesitamos
        # utilizar enteros para guardar la información decimal. En otras palabras, necesitamos
        # usar una codificación con precisión fija.
        fixed_precision_param = copy_of_parameter.fix_precision()
        
        # ahora encriptamos esto en una máquina remota. Nota que
        # fixed_precision_param ya es un puntero. Entonces, cuando
        # llamamos share encripta los datos a los que se apunta. Esto
        # regresa un puntero al objeto secreto compartido en el CMP,
        # que necesitamos tomar.
        encrypted_param = fixed_precision_param.share(bob, alice, crypto_provider=james)
        
        # ahora tomamos el puntero
        param = encrypted_param.get()
        
        # guarda el parámetro para promediarlo con el mismo parámetro de
        # los otros trabajadores
        spdz_params.append(param)

    # promedia params con múltiples trabajadores, tómalos a la máquina local
    # desencripta y decodifica (de la precisión fija) al número de punto flotante
    new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2
    
    # guarda en nuevo parámetro promediado
    new_params.append(new_param)

Parte C: Limpieza


In [ ]:
with torch.no_grad():
    for model in params:
        for param in model:
            param *= 0

    for model in models:
        model.get()

    for remote_index in range(len(compute_nodes)):
        for param_index in range(len(params[remote_index])):
            params[remote_index][param_index].set_(new_params[param_index])

¡Ahora lo juntamos!

Y ahora que conocemos cada paso, podemos juntarlo en un ciclo de entrenamiento.


In [ ]:
def train(epoch):
    for data_index in range(len(remote_dataset[0])-1):
        # actualiza los modelos remotos
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])

        # agregación encriptada
        new_params = list()
        for param_i in range(len(params[0])):
            spdz_params = list()
            for remote_index in range(len(compute_nodes)):
                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=james).get())

            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2
            new_params.append(new_param)

        # limpieza
        with torch.no_grad():
            for model in params:
                for param in model:
                    param *= 0

            for model in models:
                model.get()

            for remote_index in range(len(compute_nodes)):
                for param_index in range(len(params[remote_index])):
                    params[remote_index][param_index].set_(new_params[param_index])

In [ ]:
def test():
    models[0].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # suma la pérdida
        pred = output.data.max(1, keepdim=True)[1] # obtén el índice del máximo de la probabilidad logarítmica
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))

In [ ]:
t = time.time()

for epoch in range(args.epochs):
    print(f"Epoch {epoch + 1}")
    train(epoch)
    test()

    
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

!Felicitaciones! - !Es hora de unirte a la comunidad!

¡Felicitaciones por completar esta parte del tutorial! Si te gustó y quieres unirte al movimiento para preservar la privacidad, propiedad descentralizada de IA y la cadena de suministro de IA (datos), puedes hacerlo de las ¡siguientes formas!

Dale una estrella a PySyft en GitHub

La forma más fácil de ayudar a nuestra comunidad es por darle estrellas a ¡los repositorios de Github! Esto ayuda a crear consciencia de las interesantes herramientas que estamos construyendo.

¡Únete a nuestro Slack!

La mejor manera de mantenerte actualizado con los últimos avances es ¡unirte a la comunidad! Tú lo puedes hacer llenando el formulario en http://slack.openmined.org

¡Únete a un proyecto de código!

La mejor manera de contribuir a nuestra comunidad es convertirte en un ¡contribuidor de código! En cualquier momento puedes ir al Github Issues de PySyft y filtrar por "Proyectos". Esto mostrará todos los tiquetes de nivel superior dando un resumen de los proyectos a los que ¡te puedes unir! Si no te quieres unir a un proyecto, pero quieres hacer un poco de código, también puedes mirar más mini-proyectos "de una persona" buscando por Github Issues con la etiqueta "good first issue".

Donar

Si no tienes tiempo para contribuir a nuestra base de código, pero quieres ofrecer tu ayuda, también puedes aportar a nuestro Open Collective". Todas las donaciones van a nuestro web hosting y otros gastos de nuestra comunidad como ¡hackathons y meetups!

OpenMined's Open Collective Page


In [ ]: